spikeinterface motion estimation
motion estimation in spikeinterface¶
In 2021, the SpikeInterface project started to implement sortingcomponents, a module that provides modular building blocks for the individual steps of spike sorting.
Here is an overview of our progress in integrating motion (aka drift) estimation and correction.
This notebook will be based on the open dataset from Nick Steinmetz published in 2021 "Imposed motion datasets" from Steinmetz et al. Science 2021 https://figshare.com/articles/dataset/_Imposed_motion_datasets_from_Steinmetz_et_al_Science_2021/14024495
The motion estimation is done in several modular steps:
- detect peaks
- localize peaks:
- "center_of_mass"
- "monopolar_triangulation" by Julien Boussard and Erdem Varol https://openreview.net/pdf?id=ohfi44BZPC4
- estimate motion:
- rigid or non rigid
- "decentralized_registration" by Erdem Varol and Julien Boussard DOI : 10.1109/ICASSP39728.2021.9414145
- "motion cloud" by Julien Boussard (not implemented yet)
Here we will show this chain:
- detect peaks > localize peaks with "monopolar_triangulation" > estimate motion with "decentralized_registration"
In [1]:
# Load IPython's autoreload extension and switch it to mode 2, so that modules
# edited on disk are re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
In [2]:
# --- standard library ---
from pathlib import Path

# --- numerical / plotting stack ---
import matplotlib.pyplot as plt
import numpy as np

# --- spike-sorting toolchain ---
import spikeinterface.full as si
from probeinterface.plotting import plot_probe
from spikeinterface.sortingcomponents import detect_peaks, localize_peaks

# Default figure size for every plot in this notebook (individual cells may
# still override it through `figsize=`).
plt.rcParams["figure.figsize"] = (20, 12)
In [3]:
# Local data layout: every path below hangs off one base directory.
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

# Input: folder handed to the SpikeGLX reader.
dataset_folder = base_folder.joinpath('dataset1/NP1')

# Outputs: where the preprocessed recording and the peak arrays are cached.
preprocess_folder = base_folder.joinpath('dataset1_NP1_preprocessed')
peak_folder = base_folder.joinpath('dataset1_NP1_peaks')

# Only the peak cache folder is created here (no error if it already exists);
# the preprocess folder is written later by the recording's save() call.
peak_folder.mkdir(exist_ok=True)
In [4]:
# Shared keyword arguments for parallel computing; they are unpacked into every
# chunked computation below (save, detect_peaks, localize_peaks).
job_kwargs = {
    'n_jobs': 40,
    'chunk_memory': '10M',
    'progress_bar': True,
}
In [5]:
# read the file
# Open the raw recording found in `dataset_folder` with the SpikeGLX reader.
rec = si.read_spikeglx(dataset_folder)
# Bare expression: the notebook shows the recording's repr as the cell output.
rec
Out[5]:
In [11]:
# Draw the probe attached to the raw recording and restrict the view to the
# y-range -150..200.
fig, ax = plt.subplots()
plot_probe(rec.get_probe(), ax=ax)
ax.set_ylim(bottom=-150, top=200)
Out[11]:
preprocess¶
This takes about 4 min for 30 min of signal.
In [7]:
# Preprocessing chain: band-pass filter (300-6000 Hz), then a global
# common reference using the median operator.
rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')

# Writing the preprocessed recording to disk is the slow step, and save()
# refuses to write into a folder that already exists. Skip it when the cache
# from a previous run is present, so the notebook survives "Restart & Run All".
# NOTE: a folder left behind by an interrupted save is treated as valid here;
# delete it by hand to force a recompute.
if not preprocess_folder.exists():
    rec_preprocessed.save(folder=preprocess_folder, **job_kwargs)
Out[7]:
In [5]:
# load back
# Rebind `rec_preprocessed` to the recording stored in `preprocess_folder`,
# so the following cells use the copy loaded from disk.
rec_preprocessed = si.load_extractor(preprocess_folder)
# Bare expression: display the loaded recording's repr.
rec_preprocessed
Out[5]:
In [12]:
# plot and check spikes
# Show the preprocessed traces for time_range=(100, 110) (presumably seconds)
# on the channels at index 50..59. The channel ids are taken from the recording
# being plotted, so this cell no longer depends on the raw `rec` object.
si.plot_timeseries(
    rec_preprocessed,
    time_range=(100, 110),
    channel_ids=rec_preprocessed.channel_ids[50:60],
)
Out[12]:
estimate noise¶
In [14]:
# Noise levels of the preprocessed recording, in unscaled units
# (return_scaled=False); reused later by the peak detection.
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)

# Histogram of the noise levels with unit-wide bins from 0 to 9.
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(noise_levels, bins=np.arange(10))
ax.set_title('noise across channel')
Out[14]:
detect peaks¶
This takes about 1 min 30 s.
In [15]:
from spikeinterface.sortingcomponents import detect_peaks
In [16]:
# Detect negative peaks on the preprocessed recording with the
# 'locally_exclusive' method, using the noise levels estimated above.
peaks = detect_peaks(
    rec_preprocessed,
    method='locally_exclusive',
    peak_sign='neg',
    detect_threshold=5,
    noise_levels=noise_levels,
    local_radius_um=100,
    n_shifts=5,
    **job_kwargs,
)

# Cache the detected peaks so later sessions can reload them instead of
# recomputing.
np.save(peak_folder.joinpath('peaks.npy'), peaks)
In [8]:
# Reload the cached peaks from disk and report how many there are.
peaks = np.load(peak_folder.joinpath('peaks.npy'))
print(peaks.shape)
In [18]:
from spikeinterface.sortingcomponents import localize_peaks
In [19]:
# Localize every detected peak with the 'center_of_mass' method.
# ms_before / ms_after presumably set the waveform window around each peak.
peak_locations = localize_peaks(
    rec_preprocessed,
    peaks,
    method='center_of_mass',
    method_kwargs=dict(local_radius_um=100.0),
    ms_before=0.3,
    ms_after=0.6,
    **job_kwargs,
)

# Persist under a method-specific file name; `peak_locations` itself is
# overwritten by the next localization cell.
np.save(peak_folder.joinpath('peak_locations_center_of_mass.npy'), peak_locations)
print(peak_locations.shape)
In [20]:
# Localize every detected peak with the 'monopolar_triangulation' method.
# ms_before / ms_after presumably set the waveform window around each peak.
peak_locations = localize_peaks(
    rec_preprocessed,
    peaks,
    method='monopolar_triangulation',
    method_kwargs=dict(local_radius_um=100.0, max_distance_um=1000.0),
    ms_before=0.3,
    ms_after=0.6,
    **job_kwargs,
)

# Persist under a method-specific file name so both localization results
# stay available on disk.
np.save(peak_folder.joinpath('peak_locations_monopolar_triangulation.npy'), peak_locations)
print(peak_locations.shape)
In [6]:
# Reload cached peak locations. The monopolar-triangulation result is used for
# the rest of the notebook; swap in the line below to use center of mass.
# peak_locations = np.load(peak_folder.joinpath('peak_locations_center_of_mass.npy'))
peak_locations = np.load(peak_folder.joinpath('peak_locations_monopolar_triangulation.npy'))
print(peak_locations)
plot peak on probe¶
In [16]:
# Overlay the localized peaks (x, y) on the probe drawing. The very low alpha
# keeps dense regions readable.
fig, ax = plt.subplots(figsize=(15, 10))
probe = rec_preprocessed.get_probe()
plot_probe(probe, ax=ax)
ax.scatter(
    peak_locations['x'],
    peak_locations['y'],
    s=1,
    color='k',
    alpha=0.002,
)
# Alternative zoom: ax.set_ylim(2400, 2900)
ax.set_ylim(bottom=1500, top=2500)
Out[16]:
plot peak depth vs time¶
In [11]:
# Raster of peak depth against time: sample index divided by the sampling
# frequency on x, the localized 'y' coordinate on y.
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']

fig, ax = plt.subplots()
ax.scatter(
    x,
    y,
    s=1,
    color='k',
    alpha=0.05,
)
ax.set_ylim(bottom=1300, top=2500)
Out[11]:
motion estimate : rigid with decentralized¶
In [17]:
from spikeinterface.sortingcomponents import (
estimate_motion,
make_motion_histogram,
compute_pairwise_displacement,
compute_global_displacement
)
In [18]:
# Histogram resolution: spatial bin size (um) and temporal bin duration (s).
# `bin_um` is reused by the pairwise-displacement cell.
bin_um = 2
bin_duration_s = 5.0

# Build the motion histogram of the peaks along the 'y' direction, without
# amplitude weighting; the temporal and spatial bins are returned alongside it.
motion_histogram, temporal_bins, spatial_bins = make_motion_histogram(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_duration_s=bin_duration_s,
    bin_um=bin_um,
    weight_with_amplitude=False,
)
print(motion_histogram.shape, temporal_bins.size, spatial_bins.size)
In [22]:
# Show the motion histogram as an image, transposed so that time runs along x
# and depth along y; the extent maps pixels to the first/last bin values.
extent = (temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1])

fig, ax = plt.subplots()
im = ax.imshow(
    motion_histogram.T,
    extent=extent,
    origin='lower',
    aspect='auto',
    interpolation='nearest',
    vmin=0,   # colour limits 0..15, same as im.set_clim(0, 15)
    vmax=15,
)
ax.set_ylim(bottom=1300, top=2500)
ax.set_xlabel('time[s]')
ax.set_ylabel('depth[um]')
Out[22]:
pairwise displacement from the motion histogram¶
In [23]:
# Pairwise displacement computed from the motion histogram with the 'conv2d'
# method, using the same spatial bin size as the histogram.
pairwise_displacement = compute_pairwise_displacement(
    motion_histogram,
    bin_um,
    method='conv2d',
)

# Cache the displacement matrix next to the peak arrays.
np.save(peak_folder.joinpath('pairwise_displacement_conv2d.npy'), pairwise_displacement)
In [24]:
# Show the pairwise displacement matrix with time on both axes, a diverging
# colour map clipped to +/-40, and square pixels.
extent = (temporal_bins[0], temporal_bins[-1], temporal_bins[0], temporal_bins[-1])

fig, ax = plt.subplots()
im = ax.imshow(
    pairwise_displacement,
    extent=extent,
    origin='lower',
    cmap='PiYG',
    interpolation='nearest',
    aspect='equal',  # final aspect; replaces aspect='auto' followed by set_aspect('equal')
    vmin=-40,        # colour limits, same as im.set_clim(-40, 40)
    vmax=40,
)
fig.colorbar(im)
Out[24]:
estimate motion (rigid) from the pairwise displacement¶
In [25]:
# Derive the rigid motion estimate from the pairwise displacement matrix.
motion = compute_global_displacement(pairwise_displacement)
In [26]:
# Plot the rigid motion estimate against time. The last element of
# `temporal_bins` is dropped for the x values — presumably because
# `temporal_bins` holds one more entry than `motion` (TODO confirm).
fig, ax = plt.subplots()
ax.plot(temporal_bins[:-1], motion)
Out[26]:
motion estimation with a single function¶
Internally estimate_motion() does:
- make_motion_histogram()
- compute_pairwise_displacement()
- compute_global_displacement()
In [27]:
# One-call rigid motion estimation (non_rigid_kwargs=None) along 'y'.
# Note that this rebinds `motion`, `temporal_bins` and `spatial_bins`, and uses
# bin_um=10 rather than the 2 um used in the step-by-step cells.
motion, temporal_bins, spatial_bins = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_um=10.0,
    bin_duration_s=5.0,
    method='decentralized_registration',
    method_kwargs=dict(),
    non_rigid_kwargs=None,
    verbose=True,
    progress_bar=True,
)
In [30]:
# Peak raster (time vs. depth) with the rigid motion estimate drawn on top.
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']

fig, ax = plt.subplots()
ax.scatter(
    x,
    y,
    s=1,
    color='k',
    alpha=0.05,
)
ax.set_ylim(bottom=1300, top=2500)

# The motion trace is offset by 2000 so it lands inside the displayed range.
ax.plot(temporal_bins[:-1], motion + 2000, color='r')
ax.set_xlabel('times[s]')
ax.set_ylabel('motion [um]')
Out[30]:
motion estimation non rigid¶
In [31]:
# Non-rigid motion estimation: passing non_rigid_kwargs (here bin_step_um=200)
# is what switches estimate_motion away from the rigid case.
# Rebinds `motion`, `temporal_bins` and `spatial_bins`.
motion, temporal_bins, spatial_bins = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_um=10.0,
    bin_duration_s=5.0,
    method='decentralized_registration',
    method_kwargs=dict(),
    non_rigid_kwargs={'bin_step_um': 200},
    verbose=True,
    progress_bar=True,
)
print(motion.shape)
print(temporal_bins.shape)
In [32]:
# Peak raster (time vs. depth) with the non-rigid motion estimate on top.
fs = rec_preprocessed.get_sampling_frequency()

fig, ax = plt.subplots()
ax.scatter(peaks['sample_ind'] / fs, peak_locations['y'], color='k', s=0.1, alpha=0.05)

# One motion trace per spatial bin: column i of `motion` is offset by the
# position of spatial bin i, so each trace is drawn at its own depth.
for i, spatial_bin in enumerate(spatial_bins):
    ax.plot(temporal_bins[:-1], motion[:, i] + spatial_bin, color='r')

ax.set_ylim(1300, 2500)
ax.set_xlabel('times[s]')
ax.set_ylabel('motion [um]')
Out[32]:
In [ ]: